
# coding: utf-8

# In[2]:

#My name: Joe Puccio
#Collaborators: Fred Landis, Alan Wu, Zen Yang

#Helper functions courtesy of Master McMillan
import string
import array as arrayModule
import sys
from numpy import *

# The following classes are used to construct 
# and represent Suffix trees

class Edge:
    """Suffix-tree edge labelled by the inclusive slice ``string[first..last]``.

    ``dstNode`` is the index of the node this edge leads into.  The class
    attribute ``count`` tallies every Edge ever constructed.
    """
    count = 0

    def __init__(self, dstNode, first, last):
        Edge.count += 1
        self.dstNode, self.first, self.last = dstNode, first, last

    def split(self, suffix, suffixTree):
        """Cut this edge after ``len(suffix)`` characters.

        Appends a new node entry (initialised to 0) to ``suffixTree.nodes``,
        hangs the remainder of this edge's label off that node as a fresh Edge
        (keyed by its first character), and shortens this edge so it ends at
        the new node.  Returns the new node's index.
        """
        midNode = len(suffixTree.nodes)
        suffixTree.nodes.append(0)
        tailStart = self.first + len(suffix)
        # The tail must be built before this edge is shortened: it inherits
        # the current destination and end position.
        tail = Edge(self.dstNode, tailStart, self.last)
        suffixTree.edgeLookup[midNode, suffixTree.string[tailStart]] = tail
        self.last = tailStart - 1
        self.dstNode = midNode
        return midNode

    def isLeafEdge(self, suffixTree):
        """True when the label runs through the final character of the string."""
        return suffixTree.lastIndex == self.last

    def __len__(self):
        return 1 + self.last - self.first

    def __repr__(self):
        fields = (self.dstNode, self.first, self.last)
        return "Edge(%d, %d, %d)" % fields

class Suffix:
    """Position in a suffix tree: node ``srcNode`` plus label ``string[first..last]``.

    The suffix is explicit (it ends exactly at ``srcNode``) when the label is
    empty, i.e. ``first > last``.  Otherwise it is implicit and ends partway
    along the edge out of ``srcNode`` that starts with ``string[first]``.
    """
    def __init__(self, srcNode, first, last):
        self.srcNode, self.first, self.last = srcNode, first, last

    def is_explicit(self):
        return self.last < self.first

    def is_implicit(self):
        return not self.is_explicit()

    def canonicalize(self, suffixTree):
        """Move ``srcNode`` down while a whole edge fits inside the label.

        Afterwards ``srcNode`` is the deepest node on the suffix's path, and
        the label (if any) is shorter than the edge it lies on.
        """
        while self.is_implicit():
            edge = suffixTree.edgeLookup[self.srcNode, suffixTree.string[self.first]]
            if len(edge) > len(self):
                break
            self.first += len(edge)
            self.srcNode = edge.dstNode

    def __len__(self):
        return (self.last + 1) - self.first

    def __repr__(self):
        fields = (self.srcNode, self.first, self.last)
        return "Suffix(%d, %d, %d)" % fields

class SuffixTree:
    """Suffix tree of ``string`` built online, one character at a time.

    Attributes:
        string: the indexed text.
        alphabet: set of distinct characters occurring in ``string``.
        nodes: one entry per node.  Internal nodes hold the index of their
            suffix-link target (the root and freshly split nodes start at 0);
            leaves hold -1, which is how ``leafIndices`` recognises them.
        edgeLookup: dict mapping ``(srcNode, firstChar)`` to the outgoing Edge.
        lastIndex: index of the final character of ``string``.

    Nothing is appended to ``string`` as a terminator.  A suffix that also
    occurs earlier in the text can therefore end inside an edge without a
    leaf of its own, and leaf-based queries will not report it.
    """
    def __init__(self, string):
        self.string = string            # save a pointer to the string
        self.alphabet = set()           # alphabet of string
        # Per-node suffix link; -1 marks a leaf (see class docstring).
        self.nodes = arrayModule.array('l')
        self.nodes.append(0)            # add root node
        self.edgeLookup = {}            # adjacency list indexed by (srcNode, char)
        self.lastIndex = len(string) - 1
        # The active point starts as the empty suffix at the root.
        activePoint = Suffix(0, 0, -1)
        # enumerate() is lazy on both Python 2 and 3 (xrange is Python 2 only).
        for i, ch in enumerate(string):
            self.alphabet.add(ch)
            self.addPrefix(i, activePoint)

    def addPrefix(self, last, activePoint):
        """Extend the tree with character ``string[last]``.

        ``activePoint`` is updated in place.  Starting from it, each suffix
        that does not yet continue with ``string[last]`` gets a new leaf edge
        labelled ``string[last..lastIndex]``.  If the suffix ends mid-edge,
        that edge is split first.  The loop stops at the first suffix that
        already continues with the character.  Internal nodes created along
        the way are chained with suffix links stored in ``self.nodes``.
        """
        LastParentNode = -1
        while True:
            ParentNode = activePoint.srcNode
            if activePoint.is_explicit():
                # Suffix ends at a node: done if an edge already starts with the new char.
                if (activePoint.srcNode, self.string[last]) in self.edgeLookup:
                    break
            else:               #potentially split an implicit node
                edge = self.edgeLookup[activePoint.srcNode, self.string[activePoint.first]]
                # Suffix ends mid-edge: done if the next char on that edge already matches.
                if (self.string[edge.first + len(activePoint)] == self.string[last]):
                    break
                else:
                    ParentNode = edge.split(activePoint, self)
            # New leaf node (-1 marks a leaf) hanging off ParentNode.  Leaf edges
            # end at lastIndex because the whole string is known in advance.
            self.nodes.append(-1)
            self.edgeLookup[ParentNode, self.string[last]] = Edge(len(self.nodes)-1, last, self.lastIndex)
            # add suffix link
            if (LastParentNode > 0):
                self.nodes[LastParentNode] = ParentNode
            LastParentNode = ParentNode
            # Move to the next shorter suffix.  At the root, drop its first char;
            # elsewhere, follow srcNode's suffix link.
            if (activePoint.srcNode == 0):
                activePoint.first += 1
            else:
                activePoint.srcNode = self.nodes[activePoint.srcNode]
            activePoint.canonicalize(self)
        if (LastParentNode > 0):
            self.nodes[LastParentNode] = ParentNode
        # The suffix that stopped the loop now also covers string[last].
        activePoint.last += 1
        activePoint.canonicalize(self)

    def leafIndices(self, nodeIndex=0, lenSoFar=0):
        """Return start indices of all suffixes whose leaves lie below ``nodeIndex``.

        ``lenSoFar`` must be the string depth of ``nodeIndex``, i.e. the number
        of characters on the path from the root to it.
        """
        indexList = []
        if (self.nodes[nodeIndex] < 0):
            # nodeIndex is itself a leaf: its suffix is lenSoFar characters long.
            indexList.append(self.lastIndex + 1 - lenSoFar)
        else:
            for char in self.alphabet:
                edge = self.edgeLookup.get((nodeIndex, char))
                if edge is None:
                    continue            # no edge from this node starts with char
                if edge.isLeafEdge(self):
                    indexList.append(self.lastIndex + 1 - len(edge) - lenSoFar)
                else:
                    indexList += self.leafIndices(edge.dstNode, lenSoFar + len(edge))
        return indexList

    def distinct(self, nodeIndex=0, lenSoFar=0):
        """Return ``(start, length)`` pairs for leaf edges below ``nodeIndex``.

        For every leaf edge, ``start`` is the start index of that leaf's suffix,
        and ``length`` is ``lenSoFar + 1``.  That is the depth of the node the
        edge leaves, plus the edge's first character.  ``lenSoFar`` must be the
        string depth of ``nodeIndex``.
        """
        distinctList = []
        # examine children of node
        for char in self.alphabet:
            # Only a missing edge is expected here.  The old bare ``except:``
            # also hid RecursionError/KeyboardInterrupt from the recursive call
            # and silently returned a truncated list.
            edge = self.edgeLookup.get((nodeIndex, char))
            if edge is None:
                continue
            if edge.isLeafEdge(self):
                distinctList.append((edge.first-lenSoFar, lenSoFar+1))
            else:
                distinctList += self.distinct(edge.dstNode, lenSoFar + len(edge))
        return distinctList

    def thread(self, target):
        """Return start indices of leaf-recorded occurrences of ``target``.

        Walks ``target`` down from the root and returns ``[]`` as soon as it
        falls off the tree.
        """
        nodeIndex = 0           # start from root
        i = 0                   # characters threaded so far
        while (i < len(target)):
            edge = self.edgeLookup.get((nodeIndex, target[i]))
            if edge is None:
                return []
            prefix = self.string[edge.first:edge.last+1]
            rest = target[i:]
            if not (rest.startswith(prefix) or prefix.startswith(rest)):
                return []
            # When the match ends mid-edge, i overshoots len(target).  It then
            # equals the string depth of edge.dstNode, which leafIndices needs.
            i += len(prefix)
            nodeIndex = edge.dstNode
        return self.leafIndices(nodeIndex, i)

    def printTree(self, nodeIndex=0, prefixLength=0, prefix=""):
        """Print every root-to-leaf path below ``nodeIndex``.

        Each line lists ``(node)-label-`` segments and ends with
        ``(leaf)[start]``.  Every leaf-level call returns an empty list, so the
        returned ``branches`` is always ``[]``; the printed lines are the real
        output.
        """
        branches = list()
        for c in self.alphabet:
            edge = self.edgeLookup.get((nodeIndex, c))
            if edge is None:
                continue
            extent = prefix + "(%d)-%s-" % (nodeIndex, self.string[edge.first:edge.last+1])
            subtree = self.printTree(edge.dstNode, prefixLength+len(edge), extent)
            if (len(subtree) > 0):
                branches.append(subtree)
            elif (edge.isLeafEdge(self)):
                # Single-argument print(...) prints the same text on Python 2 and 3.
                print(extent + "(%d)[%d]" % (edge.dstNode, edge.first-prefixLength))
        return branches
    
# compute a suffix array by sorting suffix start positions
def argsort(text):
    """Return the suffix array of ``text``: start indices ordered by their suffixes.

    Two suffixes are sliced only when they are compared.  Precomputing every
    suffix as a sort key would hold O(n^2) characters in memory at once.
    ``functools.cmp_to_key`` replaces the ``cmp=`` keyword, which Python 3
    removed; it exists on Python 2.7 as well.

    >>> argsort("banana")
    [5, 3, 1, 0, 4, 2]
    """
    from functools import cmp_to_key

    def compare(i, j):
        # Proper three-way comparison (the old lambda never returned 0).
        a, b = text[i:], text[j:]
        return (a > b) - (a < b)

    return sorted(range(len(text)), key=cmp_to_key(compare))

def findFirst(pattern, text, sfa):
    """Return the first position in suffix array ``sfa`` whose suffix is >= ``pattern``.

    If ``pattern`` occurs in ``text``, this is the suffix-array index of its
    first occurrence.  Otherwise it is where ``pattern`` would be inserted.
    Paired with ``findLast``, the occurrences are ``sfa[findFirst:findLast]``.
    """
    m = len(pattern)
    lo, hi = 0, len(sfa)
    while lo < hi:
        mid = (lo + hi) // 2
        start = sfa[mid]
        # Comparing against the m-character prefix gives the same answer as
        # comparing against the whole suffix, without copying O(n) chars per probe.
        if pattern > text[start:start + m]:
            lo = mid + 1
        else:
            hi = mid
    return lo

def findLast(pattern, text, sfa):
    """Return one past the suffix-array index of the last occurrence of ``pattern``.

    This is the first position in ``sfa`` whose suffix's first ``len(pattern)``
    characters sort after ``pattern``.  Paired with ``findFirst``,
    ``findLast - findFirst`` is the number of occurrences, and
    ``sfa[findFirst:findLast]`` lists their start indices.
    """
    m = len(pattern)
    lo, hi = 0, len(sfa)
    while lo < hi:
        mid = (lo + hi) // 2
        start = sfa[mid]
        # Suffixes that start with pattern compare equal on this prefix, so
        # the search moves past all of them.
        if pattern >= text[start:start + m]:
            lo = mid + 1
        else:
            hi = mid
    return lo


# In[3]:

# Read the chromosome as raw bytes; the with-block closes the file promptly.
with open("HumChr01.seq", 'rb') as seqFile:
    chr1 = seqFile.read()
import sys
# Several suffix-tree methods above are recursive, so raise the default
# limit to let them walk deeper trees.
sys.setrecursionlimit(10000)


# In[4]:

#Problem 1

import time 

N = 100
treeTime = 0.0
# Benchmark SuffixTree construction and search on prefixes of chr1 that start
# at offset 10000 and grow tenfold each round (N = 100, 1000, ...).
# Each output row is: N, node count, build time, "CAT" match count, search time.
# The loop stops after N = 10**8, once a build takes 100 s or more, or when the
# build raises RuntimeError (Python 2 reports hitting the recursion limit this way).
# time.clock() is the Python 2 timer here; it was removed in Python 3.8.
while (N <= 100000000) and (treeTime < 100.0):
        print "%10d, " % N,
        text = chr1[10000:10000+N]
        start = time.clock()
        try:
            sTree= SuffixTree(text)
        except RuntimeError:
            print "Crashed"
            break
        treeTime = time.clock() - start
        print "%8d, %6.3f secs" % (len(sTree.nodes), treeTime),
        start = time.clock()
        # Start positions of "CAT" among the suffixes recorded as leaves.
        catsfound = sTree.thread("CAT")
        findTime = time.clock() - start
        print "%8d, %6.3f secs" % (len(catsfound), findTime)
        N*=10
            


# In[11]:

#Problem 2
N = 100
arrayTime = 0.0
# Benchmark the suffix-array approach on the same growing prefixes as Problem 1.
# First time argsort, then time counting "CAT" with the two binary searches.
# The loop stops after N = 10**8 or once argsort takes 100 s or more.
while (N <= 100000000) and (arrayTime < 100.0):
        print "%10d, " % N,
        text = chr1[10000:10000+N]
        start = time.clock()
        sArray = argsort(text)
        arrayTime = time.clock() - start
        print "%8d, %6.3f secs" % (len(sArray), arrayTime),
        start = time.clock()
        lo = findFirst("CAT", text, sArray)
        hi = findLast("CAT", text, sArray)
        findTime = time.clock() - start
        # sArray[lo:hi] holds the start positions of "CAT"; hi - lo counts them.
        print "%8d, %6.3f secs" % (hi-lo, findTime)
        N *=10
          


# In[5]:

#Problem 3
#import matplotlib.pyplot as plt

class SuperTree(SuffixTree):
    """SuffixTree that can also list its leaves in lexicographic order."""

    def getSuffixArray(self, nodeIndex=0, prefixLength=0):
        """Collect start positions of leaf suffixes below ``nodeIndex`` in sorted order.

        Children are visited in ``sorted(self.alphabet)`` order, so leaves come
        out in lexicographic order of their suffixes.  ``prefixLength`` is the
        string depth of ``nodeIndex``.  ``self.suffixArray`` is reset only when
        starting from the root (``nodeIndex == 0``); any other start appends to
        the existing list.  Returns ``self.suffixArray``.
        """
        if nodeIndex == 0:
            self.suffixArray = []
        found = self.suffixArray
        # Explicit DFS stack of (edge, depth of the edge's source node).
        # Children are pushed in reverse order, so they pop in sorted order.
        # This reproduces a recursive pre-order walk without Python recursion.
        pending = []

        def pushChildren(node, depth):
            for ch in reversed(sorted(self.alphabet)):
                edge = self.edgeLookup.get((node, ch))
                if edge is not None:
                    pending.append((edge, depth))

        pushChildren(nodeIndex, prefixLength)
        while pending:
            edge, depth = pending.pop()
            if edge.isLeafEdge(self):
                found.append(edge.first - depth)
            else:
                pushChildren(edge.dstNode, depth + len(edge))
        return self.suffixArray

    
N = 10000
arrayTime = 0.0
arrayTimeArray = []
treeTimeArray = []
NArray = []

while(arrayTime < 1.0):
    NArray.append(N)
    print "%10d, " % N, 
    text = chr1[10000:10000+N]
    start = time.clock()
    sArrray = argsort(text)
    arrayTime = time.clock() - start
    arrayTimeArray.append(arrayTime)
    print "%6.3f secs" % (arrayTime), 
    start = time.clock()
    sTree = SuperTree(text)
    tArray = sTree.getSuffixArray()
    treeTime = time.clock() - start
    treeTimeArray.append(treeTime)
    print "%6.3f secs" % (treeTime)
    N+=1000

#plt.plot(NArray, arrayTimeArray, 'bs', NArray, treeTimeArray, 'g^')


# In[ ]:



